{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a18a22a9",
   "metadata": {},
   "source": [
    "# Constrained Atom and Bond Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "660f198a",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/constrained_mol_atom_bond.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "14c413f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "# COLAB_RELEASE_TAG is only checked for being set; outside Colab this cell is a no-op\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        # Use `%pip` rather than `!pip` so the package is installed into the running kernel's environment\n",
    "        %pip install .\n",
    "        # Continue from the examples directory so the relative data paths below resolve\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e816bc13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from pathlib import Path\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from chemprop import data, featurizers, models, nn\n",
    "\n",
    "\n",
    "# Assumes the notebook is run from the `examples/` directory, so the repository root is one level up\n",
    "chemprop_dir = Path.cwd().parent\n",
    "data_dir = chemprop_dir.joinpath(\"tests\", \"data\", \"mol_atom_bond\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3919ffe3",
   "metadata": {},
   "source": [
    "If any of the atom or bond properties should sum to a known molecule-level value, we can constrain the atom and bond predictions to sum to that value. For example, atom partial charges should sum to the total charge of the molecule."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63d73daf",
   "metadata": {},
   "source": [
    "## Make datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "524a494a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>mol_y</th>\n",
       "      <th>atom_y1</th>\n",
       "      <th>atom_y2</th>\n",
       "      <th>bond_y1</th>\n",
       "      <th>bond_y2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[H][H]</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0]</td>\n",
       "      <td>[1.008, 1.008]</td>\n",
       "      <td>[2]</td>\n",
       "      <td>[2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C</td>\n",
       "      <td>0</td>\n",
       "      <td>[0]</td>\n",
       "      <td>[12.011]</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CN</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0]</td>\n",
       "      <td>[12.011, 14.007]</td>\n",
       "      <td>[13]</td>\n",
       "      <td>[2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>CN</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0]</td>\n",
       "      <td>[12.011, 14.007]</td>\n",
       "      <td>[13]</td>\n",
       "      <td>[2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CC</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0]</td>\n",
       "      <td>[12.011, 12.011]</td>\n",
       "      <td>[12]</td>\n",
       "      <td>[2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[CH2:3]=[N+:1]([H:4])[H:2]</td>\n",
       "      <td>1</td>\n",
       "      <td>[1, 0, 0, 0]</td>\n",
       "      <td>[14.007, 1.008, 12.011, 1.008]</td>\n",
       "      <td>[13, 8, 8]</td>\n",
       "      <td>[4, 2, 2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>CCCC</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0, 0, 0]</td>\n",
       "      <td>[12.011, 12.011, 12.011, 12.011]</td>\n",
       "      <td>[12, 12, 12]</td>\n",
       "      <td>[2, 2, 2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>CO</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0]</td>\n",
       "      <td>[12.011, 15.999]</td>\n",
       "      <td>[14]</td>\n",
       "      <td>[2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>CC#N</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0, 0]</td>\n",
       "      <td>[12.011, 12.011, 14.007]</td>\n",
       "      <td>[12, 13]</td>\n",
       "      <td>[2, 6]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>C1NN1</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0, 0]</td>\n",
       "      <td>[12.011, 14.007, 14.007]</td>\n",
       "      <td>[13, 14, 13]</td>\n",
       "      <td>[2, 2, 2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>c1cc[n-]c1</td>\n",
       "      <td>-1</td>\n",
       "      <td>[0, 0, 0, -1, 0]</td>\n",
       "      <td>[12.011, 12.011, 12.011, 14.007, 12.011]</td>\n",
       "      <td>[12, 12, 13, 13, 12]</td>\n",
       "      <td>[3, 3, 3, 3, 3]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        smiles  mol_y           atom_y1  \\\n",
       "0                       [H][H]      0            [0, 0]   \n",
       "1                            C      0               [0]   \n",
       "2                           CN      0            [0, 0]   \n",
       "3                           CN      0            [0, 0]   \n",
       "4                           CC      0            [0, 0]   \n",
       "5   [CH2:3]=[N+:1]([H:4])[H:2]      1      [1, 0, 0, 0]   \n",
       "6                         CCCC      0      [0, 0, 0, 0]   \n",
       "7                           CO      0            [0, 0]   \n",
       "8                         CC#N      0         [0, 0, 0]   \n",
       "9                        C1NN1      0         [0, 0, 0]   \n",
       "10                  c1cc[n-]c1     -1  [0, 0, 0, -1, 0]   \n",
       "\n",
       "                                     atom_y2               bond_y1  \\\n",
       "0                             [1.008, 1.008]                   [2]   \n",
       "1                                   [12.011]                    []   \n",
       "2                           [12.011, 14.007]                  [13]   \n",
       "3                           [12.011, 14.007]                  [13]   \n",
       "4                           [12.011, 12.011]                  [12]   \n",
       "5             [14.007, 1.008, 12.011, 1.008]            [13, 8, 8]   \n",
       "6           [12.011, 12.011, 12.011, 12.011]          [12, 12, 12]   \n",
       "7                           [12.011, 15.999]                  [14]   \n",
       "8                   [12.011, 12.011, 14.007]              [12, 13]   \n",
       "9                   [12.011, 14.007, 14.007]          [13, 14, 13]   \n",
       "10  [12.011, 12.011, 12.011, 14.007, 12.011]  [12, 12, 13, 13, 12]   \n",
       "\n",
       "            bond_y2  \n",
       "0               [2]  \n",
       "1                []  \n",
       "2               [2]  \n",
       "3               [2]  \n",
       "4               [2]  \n",
       "5         [4, 2, 2]  \n",
       "6         [2, 2, 2]  \n",
       "7               [2]  \n",
       "8            [2, 6]  \n",
       "9         [2, 2, 2]  \n",
       "10  [3, 3, 3, 3, 3]  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per molecule: a molecule-level target plus stringified per-atom and per-bond target lists\n",
    "df_input = pd.read_csv(data_dir.joinpath(\"constrained_regression.csv\"))\n",
    "df_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "64e34f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "columns = [\"smiles\", \"mol_y\", \"atom_y1\", \"atom_y2\", \"bond_y1\", \"bond_y2\"]\n",
    "smis = df_input.loc[:, columns[0]].values\n",
    "mol_ys = df_input.loc[:, columns[1:2]].values\n",
    "\n",
    "\n",
    "def parse_list_targets(df, target_cols):\n",
    "    \"\"\"Parse stringified per-atom or per-bond target lists into one array per molecule.\n",
    "\n",
    "    Each cell of ``df[target_cols]`` must hold the string form of a Python list with one\n",
    "    value per atom (or bond). Returns a list with one float array per row of ``df``, of\n",
    "    shape ``(n_atoms_or_bonds, len(target_cols))``.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        # literal_eval parses the list literal without executing code; the transpose puts\n",
    "        # atoms/bonds on the first axis and targets on the second\n",
    "        np.array([ast.literal_eval(cell) for cell in row], dtype=float).T\n",
    "        for row in df.loc[:, target_cols].values\n",
    "    ]\n",
    "\n",
    "\n",
    "atoms_ys = parse_list_targets(df_input, columns[2:4])\n",
    "bonds_ys = parse_list_targets(df_input, columns[4:6])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4522ec1e",
   "metadata": {},
   "source": [
    "### Load constraints\n",
    "\n",
    "Not all atom and bond predictions need to be constrained. Here both atom predictions are constrained and only one of the bond predictions is constrained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd66d995",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>atom_y1_constraint</th>\n",
       "      <th>atom_y2_constraint</th>\n",
       "      <th>bond_y2_constraint</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2.016</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>12.011</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>26.018</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>26.018</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>24.022</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>28.034</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>48.044</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>28.010</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "      <td>38.029</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0</td>\n",
       "      <td>40.025</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>-1</td>\n",
       "      <td>62.051</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    atom_y1_constraint  atom_y2_constraint  bond_y2_constraint\n",
       "0                    0               2.016                   2\n",
       "1                    0              12.011                   0\n",
       "2                    0              26.018                   2\n",
       "3                    0              26.018                   2\n",
       "4                    0              24.022                   2\n",
       "5                    1              28.034                   8\n",
       "6                    0              48.044                   6\n",
       "7                    0              28.010                   2\n",
       "8                    0              38.029                   8\n",
       "9                    0              40.025                   6\n",
       "10                  -1              62.051                  15"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per molecule, one column per constrained atom/bond target\n",
    "df_constraints = pd.read_csv(data_dir.joinpath(\"constrained_regression_constraints.csv\"))\n",
    "df_constraints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2a5798dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_mols = len(df_constraints)\n",
    "# Constraints are matched to molecules by row position, so both tables must line up\n",
    "assert n_mols == len(df_input), \"constraints and targets must have one row per molecule\"\n",
    "\n",
    "# Maps each constrained target column to the positional index of its constraint column in\n",
    "# `df_constraints`. Target columns missing from this mapping are left unconstrained.\n",
    "constraints_cols_to_target_cols = {\n",
    "    \"atom_y1\": 0,\n",
    "    \"atom_y2\": 1,\n",
    "    \"bond_y2\": 2,\n",
    "}\n",
    "\n",
    "\n",
    "def build_constraints(constraint_cols):\n",
    "    \"\"\"Stack constraint columns into a float array of shape (n_mols, len(constraint_cols)).\n",
    "\n",
    "    ``constraint_cols`` holds, for each target, the positional index of its column in\n",
    "    ``df_constraints``, or ``None`` if the target is unconstrained. Unconstrained targets\n",
    "    get a column filled with ``np.nan``.\n",
    "    \"\"\"\n",
    "    return np.column_stack(\n",
    "        [\n",
    "            df_constraints.iloc[:, col].to_numpy(dtype=float)\n",
    "            if col is not None\n",
    "            else np.full(n_mols, np.nan)\n",
    "            for col in constraint_cols\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "# `.get` returns None for target columns without a constraint\n",
    "atom_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[2:4]]\n",
    "atom_constraints = build_constraints(atom_constraint_cols)\n",
    "\n",
    "bond_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[4:6]]\n",
    "bond_constraints = build_constraints(bond_constraint_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "049a2a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One datapoint per molecule; targets and constraints are looked up by row position\n",
    "datapoints = [\n",
    "    data.MolAtomBondDatapoint.from_smi(\n",
    "        smiles,\n",
    "        y=mol_ys[row],\n",
    "        atom_y=atoms_ys[row],\n",
    "        bond_y=bonds_ys[row],\n",
    "        atom_constraint=atom_constraints[row],\n",
    "        bond_constraint=bond_constraints[row],\n",
    "        keep_h=True,\n",
    "        add_h=False,\n",
    "        reorder_atoms=True,\n",
    "    )\n",
    "    for row, smiles in enumerate(smis)\n",
    "]\n",
    "\n",
    "# This example reuses the same datapoints for every split\n",
    "train_dataset, val_dataset, test_dataset, predict_dataset = (\n",
    "    data.MolAtomBondDataset(datapoints) for _ in range(4)\n",
    ")\n",
    "\n",
    "# If the atom/bond targets are scaled, the corresponding constraints are also scaled automatically.\n",
    "atom_target_scaler = train_dataset.normalize_targets(\"atom\")\n",
    "val_dataset.normalize_targets(\"atom\", atom_target_scaler)\n",
    "atom_target_transform = nn.UnscaleTransform.from_standard_scaler(atom_target_scaler)\n",
    "\n",
    "# Only the training loader is shuffled\n",
    "train_dataloader = data.build_dataloader(train_dataset, shuffle=True)\n",
    "val_dataloader, test_dataloader, predict_dataloader = (\n",
    "    data.build_dataloader(dataset, shuffle=False)\n",
    "    for dataset in (val_dataset, test_dataset, predict_dataset)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38f801df",
   "metadata": {},
   "source": [
    "## Set up model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cede1449",
   "metadata": {},
   "outputs": [],
   "source": [
    "mp = nn.MABBondMessagePassing()\n",
    "agg = nn.NormAggregation()\n",
    "\n",
    "# One output per target column at each level\n",
    "mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])\n",
    "atom_predictor = nn.RegressionFFN(\n",
    "    n_tasks=atoms_ys[0].shape[1],\n",
    "    # Atom targets were normalized above, so predictions are passed through the unscale transform\n",
    "    output_transform=atom_target_transform,\n",
    ")\n",
    "bond_predictor = nn.RegressionFFN(\n",
    "    n_tasks=bonds_ys[0].shape[1],\n",
    "    # The bond predictor input is twice the second output dimension of the message passing\n",
    "    input_dim=(mp.output_dims[1] * 2),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3636a6a",
   "metadata": {},
   "source": [
    "The atom/bond predictions for each constrained target are adjusted so that, within each molecule, they sum to the constraint. The amount by which each individual prediction is adjusted is determined from the node/edge fingerprints using a separate feed-forward network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d20acd32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The number of constrained targets is the number of non-NaN entries in a constraint row\n",
    "# (unconstrained targets were filled with NaN above); cast to a plain Python int\n",
    "n_atom_constraints = int(np.count_nonzero(~np.isnan(atom_constraints[0])))\n",
    "n_bond_constraints = int(np.count_nonzero(~np.isnan(bond_constraints[0])))\n",
    "\n",
    "atom_constrainer = nn.ConstrainerFFN(n_constraints=n_atom_constraints)\n",
    "# Derive the bond fingerprint width from the message passing (same expression as the bond\n",
    "# predictor's input_dim) instead of hard-coding 600\n",
    "bond_constrainer = nn.ConstrainerFFN(\n",
    "    n_constraints=n_bond_constraints, fp_dim=(mp.output_dims[1] * 2)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ab1d3bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the full model; the constrainers are passed alongside the atom and bond predictors\n",
    "model = models.MolAtomBondMPNN(\n",
    "    message_passing=mp,\n",
    "    agg=agg,\n",
    "    # predictors\n",
    "    mol_predictor=mol_predictor,\n",
    "    atom_predictor=atom_predictor,\n",
    "    bond_predictor=bond_predictor,\n",
    "    # constrainers\n",
    "    atom_constrainer=atom_constrainer,\n",
    "    bond_constrainer=bond_constrainer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "85b9f9f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MolAtomBondMPNN(\n",
       "  (message_passing): MABBondMessagePassing(\n",
       "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_vo): Linear(in_features=372, out_features=300, bias=True)\n",
       "    (W_eo): Linear(in_features=314, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (E_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): NormAggregation()\n",
       "  (mol_predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): Identity()\n",
       "  )\n",
       "  (atom_predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=2, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0, 1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (atom_constrainer): ConstrainerFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=2, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (bond_predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=2, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0, 1.0]])\n",
       "    (output_transform): Identity()\n",
       "  )\n",
       "  (bond_constrainer): ConstrainerFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (bns): ModuleList(\n",
       "    (0-2): 3 x Identity()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metricss): ModuleList(\n",
       "    (0): ModuleList(\n",
       "      (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "    )\n",
       "    (1-2): 2 x ModuleList(\n",
       "      (0): MSE(task_weights=[[1.0]])\n",
       "      (1): MSE(task_weights=[[1.0, 1.0]])\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree to inspect the predictor and constrainer layers\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6348d372",
   "metadata": {},
   "source": [
    "## The atom and bond predictions obey the constraints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "90318cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run a single batch through the not-yet-trained model without tracking gradients\n",
    "batch = next(iter(predict_dataloader))\n",
    "# The constraints are the last element of the batch; the elements in between are not needed here\n",
    "bmg, V_d, E_d, X_d, *_, constraints = batch\n",
    "with torch.no_grad():\n",
    "    outputs = model(bmg, V_d, E_d, X_d, constraints)\n",
    "mol_preds, atom_preds_tensor, bond_preds_tensor = outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ef5d0a1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.00000000e+00 -3.24249267e-08]\n",
      " [ 0.00000000e+00  3.20434570e-07]\n",
      " [ 3.72529030e-09 -1.55639648e-06]\n",
      " [ 3.72529030e-09 -1.55639648e-06]\n",
      " [ 0.00000000e+00  6.40869139e-07]\n",
      " [ 0.00000000e+00 -3.96728517e-07]\n",
      " [-1.49011612e-08  1.28173828e-06]\n",
      " [ 3.72529030e-09  1.67846680e-06]\n",
      " [ 3.72529030e-09  6.71386722e-07]\n",
      " [ 3.72529030e-09  2.28881836e-06]\n",
      " [ 0.00000000e+00  1.31225586e-06]]\n"
     ]
    }
   ],
   "source": [
    "# Sum the atom predictions within each molecule and compare the sums against the atom constraints\n",
    "atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]\n",
    "atom_preds = torch.split(atom_preds_tensor, atoms_per_mol)\n",
    "atom_pred_sums = torch.vstack([preds.sum(dim=0) for preds in atom_preds]).numpy()\n",
    "errors = predict_dataset.atom_constraints - atom_pred_sums\n",
    "print(errors)\n",
    "\n",
    "# Entries that are NaN (unconstrained targets) are excluded from the check\n",
    "is_constrained = ~np.isnan(errors)\n",
    "assert np.all(np.isclose(errors[is_constrained], 0.0, atol=1e-5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "318dbbb3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[           nan 0.00000000e+00]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 1.19209290e-07]\n",
      " [           nan 1.19209290e-07]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 4.76837158e-07]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 0.00000000e+00]\n",
      " [           nan 9.53674316e-07]]\n"
     ]
    }
   ],
   "source": [
    "# Sum the bond predictions within each molecule and compare the sums against the bond constraints\n",
    "bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]\n",
    "bond_preds = torch.split(bond_preds_tensor, bonds_per_mol)\n",
    "bond_pred_sums = torch.vstack([preds.sum(dim=0) for preds in bond_preds]).numpy()\n",
    "errors = predict_dataset.bond_constraints - bond_pred_sums\n",
    "print(errors)\n",
    "\n",
    "# Entries that are NaN (unconstrained targets) are excluded from the check\n",
    "is_constrained = ~np.isnan(errors)\n",
    "assert np.all(np.isclose(errors[is_constrained], 0.0, atol=1e-5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2175890",
   "metadata": {},
   "source": [
    "## Fit the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b3051d28",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# A short run on a single device; the accelerator is picked automatically\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=20,\n",
    "    accelerator=\"auto\",\n",
    "    devices=1,\n",
    "    logger=False,\n",
    "    enable_progress_bar=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d5f4f787",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name             | Type                  | Params | Mode \n",
      "-------------------------------------------------------------------\n",
      "0 | message_passing  | MABBondMessagePassing | 322 K  | train\n",
      "1 | agg              | NormAggregation       | 0      | train\n",
      "2 | mol_predictor    | RegressionFFN         | 90.6 K | train\n",
      "3 | atom_predictor   | RegressionFFN         | 90.9 K | train\n",
      "4 | atom_constrainer | ConstrainerFFN        | 90.9 K | train\n",
      "5 | bond_predictor   | RegressionFFN         | 180 K  | train\n",
      "6 | bond_constrainer | ConstrainerFFN        | 180 K  | train\n",
      "7 | bns              | ModuleList            | 0      | train\n",
      "8 | X_d_transform    | Identity              | 0      | train\n",
      "9 | metricss         | ModuleList            | 0      | train\n",
      "-------------------------------------------------------------------\n",
      "956 K     Trainable params\n",
      "0         Non-trainable params\n",
      "956 K     Total params\n",
      "3.824     Total estimated model params size (MB)\n",
      "72        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  3.89it/s, train_loss_step=82.10, val_loss=81.20, train_loss_epoch=82.10]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  3.26it/s, train_loss_step=82.10, val_loss=81.20, train_loss_epoch=82.10]\n"
     ]
    }
   ],
   "source": [
    "# Train on the training loader and validate on the validation loader\n",
    "trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d177eb79",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:149: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n",
      "Restoring states from the checkpoint path at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20.ckpt\n",
      "Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20.ckpt\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 20.36it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       atom_test/mse       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     3.249805212020874     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       bond_test/mse       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    15.424737930297852     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       mol_test/mse        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.16646192967891693    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      atom_test/mse      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    3.249805212020874    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      bond_test/mse      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   15.424737930297852    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      mol_test/mse       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.16646192967891693   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# No model is passed, so Lightning evaluates a checkpoint from the previous `fit` call.\n",
    "# Requesting the best checkpoint explicitly keeps that behavior and silences the warning.\n",
    "results = trainer.test(dataloaders=test_dataloader, ckpt_path=\"best\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "13ce67b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 25.98it/s]\n"
     ]
    }
   ],
   "source": [
    "# `predss` holds one (mol, atom, bond) group of tensors per batch; concatenate each group across batches\n",
    "predss = trainer.predict(model, dataloaders=predict_dataloader)\n",
    "mol_preds, atom_preds, bond_preds = [torch.cat(parts) for parts in zip(*predss)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "daccce8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]\n",
    "bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]\n",
    "\n",
    "# Rebuild the flat atom/bond tensors from `predss` rather than from `atom_preds`/`bond_preds`,\n",
    "# so this cell can be re-run without trying to split its own output\n",
    "flat_atom_preds = torch.concat([batch_preds[1] for batch_preds in predss])\n",
    "flat_bond_preds = torch.concat([batch_preds[2] for batch_preds in predss])\n",
    "\n",
    "# One tensor of predictions per molecule\n",
    "atom_preds = torch.split(flat_atom_preds, atoms_per_mol)\n",
    "bond_preds = torch.split(flat_bond_preds, bonds_per_mol)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
